# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conv2dBnFoldQuant."""
from __future__ import absolute_import

import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import QuantDtype
import mindspore.context as context
from mindspore.nn.cell import Cell
from mindspore.ops.operations import _quant_ops as Q
from mindspore_gs.validator import Validator, twice
from ...quantization.simulated_quantization.combined import Conv2dBn
from .fake_quant_with_min_max_observer import quant_config_default, QuantConfig
from .batchnorm_fold_cell import BatchNormFoldCell


class Conv2dBnFoldQuant(Cell):
    r"""
    2D convolution layer with the Batch Normalization operation folded into the convolution.

    This part gives a more detailed overview of the Conv2d operation. For more details about quantization,
    please refer to :class:`FakeQuantWithMinMaxObserver`.

    .. math::
        y = x \times w + b

        w_{q} = quant(\frac{w}{\sqrt{Var[y] + \epsilon}} \times \gamma)

        y_{out} = w_{q} \times x + \frac{b - E[y]}{\sqrt{Var[y] + \epsilon}} \times \gamma + \beta

    where :math:`quant` denotes the consecutive execution of quant and dequant. Two convolution and
    Batch Normalization passes are used here: the first convolution and Batch Normalization pass only
    computes the mean :math:`E[y]` and variance :math:`Var[y]` of the current batch output, which are
    then folded into the weight and bias used for quantization.
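
    The weight folding applied before the second convolution can be sketched as follows (a simplified,
    single-value illustration with made-up numbers; quantization omitted)::

        import numpy as np

        w = np.array([0.5])               # convolution weight of one output channel
        gamma, running_std = 1.2, 2.0     # BN scale and running std, i.e. sqrt(Var[y] + eps)
        w_fold = w * gamma / running_std  # the scaling applied by CorrectionMul
        # w_fold is now array([0.3])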

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution window.
        stride (Union[int, tuple[int]]): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input `x`. Default: 0.
        dilation (Union[int, tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        eps (float): Parameter for Batch Normalization, added to the variance to avoid dividing by
            zero. Default: 1e-5.
        momentum (float): Momentum parameter of the Batch Normalization op used to update the moving
            statistics. Default: 0.997.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
            activation. Note that QuantConfig is a special namedtuple, which is designed for quantization
            and can be generated by :func:`mindspore.compression.quant.create_quant_config` method.
            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
        freeze_bn (int): The global step after which the Batch Normalization statistics are frozen
            during quantization aware training. Default: 100000.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
        TypeError: If `has_bias` or `fake` is not a bool.
        ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
        ValueError: If `padding` is less than 0.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore.compression import quant
        >>> from mindspore import Tensor, nn
        >>> qconfig = quant.create_quant_config()
        >>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 1, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                      weight_init="ones", quant_config=qconfig)
        >>> x = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
        >>> result = conv2d_bnfold(x)
        >>> print(result)
        [[[[5.9296875 13.8359375]
           [11.859375 17.78125]]]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8,
                 freeze_bn=100000):
        """Initialize Conv2dBnFoldQuant layer"""
        super(Conv2dBnFoldQuant, self).__init__()
        if context.get_context('device_target') == "CPU":
            raise ValueError(f"For '{self.cls_name}', only the 'Ascend' and 'GPU' platforms"
                             f" are supported, but got {context.get_context('device_target')}.")
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.dilation = twice(dilation)
        for kernel_size_elem in self.kernel_size:
            Validator.check_positive_int(kernel_size_elem, 'kernel_size item', self.cls_name)
        for stride_elem in self.stride:
            Validator.check_positive_int(stride_elem, 'stride item', self.cls_name)
        for dilation_elem in self.dilation:
            Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError(f"For '{self.cls_name}', the 'pad_mode' must be one of values in "
                             f"('valid', 'same', 'pad'), but got {pad_mode}.")
        self.pad_mode = pad_mode
        if isinstance(padding, int):
            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
            self.padding = padding
        elif isinstance(padding, tuple):
            for pad in padding:
                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
            self.padding = padding
        else:
            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
                            f"but got {type(padding).__name__}!")
        self.group = Validator.check_positive_int(group, "group", self.cls_name)
        self.eps = eps
        self.momentum = momentum
        self.has_bias = has_bias
        self.freeze_bn = freeze_bn
        self.fake = Validator.check_bool(fake, "fake", self.cls_name)
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        self.is_gpu = context.get_context('device_target') == "GPU"

        # initialize convolution op and Parameter
        self.conv = P.Conv2D(out_channel=out_channels,
                             kernel_size=self.kernel_size,
                             pad_mode=pad_mode,
                             pad=padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=group)
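        # the weight is laid out as (out_channels, in_channels // group, kH, kW), so channel axis 0
        # selects per-output-channel fake quantization for the weight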
        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        channel_axis = 0
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        self.bias = None
        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')

        # initialize Batch Normalization Parameters
        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
                                         requires_grad=False)

        # initialize fake quantization and Batch Normalization fold ops
        self.fake_quant_weight = quant_config.weight(channel_axis=channel_axis,
                                                     num_channels=out_channels)
        self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
        self.correct_mul = Q.CorrectionMul(channel_axis)
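        # Ascend and GPU provide different BatchNormFold2 kernels with different input signatures,
        # so construct() chooses the call form through self.is_gpu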
        if context.get_context('device_target') == "Ascend":
            self.batchnorm_fold2_train = Q.BatchNormFold2D(freeze_bn=freeze_bn)
            self.batchnorm_fold2_infer = Q.BatchNormFold2D(freeze_bn=0)
        elif context.get_context('device_target') == "GPU":
            self.batchnorm_fold2_train = Q.BatchNormFold2(freeze_bn=freeze_bn)
            self.batchnorm_fold2_infer = Q.BatchNormFold2(freeze_bn=0)
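        # global step counter; the fold ops compare it with freeze_bn to decide when the Batch
        # Normalization statistics are frozen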
        self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
        self.one = Tensor(1, mstype.int32)
        self.assignadd = P.AssignAdd()

    @classmethod
    def from_convbn(cls, convbn: Conv2dBn, quant_config: QuantConfig, extra_args: dict):
        """
        A class method to create `Conv2dBnFoldQuant` from a `Conv2dBn`
        """
        conv_quant = cls(in_channels=convbn.conv.in_channels,
                         out_channels=convbn.conv.out_channels,
                         kernel_size=convbn.conv.kernel_size,
                         stride=convbn.conv.stride,
                         pad_mode=convbn.conv.pad_mode,
                         padding=convbn.conv.padding,
                         dilation=convbn.conv.dilation,
                         group=convbn.conv.group,
                         eps=convbn.batchnorm.eps,
                         momentum=convbn.batchnorm.momentum,
                         has_bias=convbn.conv.has_bias,
                         bias_init=convbn.conv.bias_init,
                         weight_init=convbn.conv.weight_init,
                         quant_config=quant_config,
                         fake=True,
                         freeze_bn=extra_args["freeze_bn"])
        conv_quant.gamma = convbn.batchnorm.gamma
        conv_quant.beta = convbn.batchnorm.beta
        conv_quant.moving_mean = convbn.batchnorm.moving_mean
        conv_quant.moving_variance = convbn.batchnorm.moving_variance
        conv_quant.weight = convbn.conv.weight
        if convbn.conv.has_bias:
            conv_quant.bias = convbn.conv.bias
        return conv_quant

    def extend_repr(self):
        """Display instance object as string."""
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, freeze_bn={}, momentum={}'.format(self.in_channels, self.out_channels, self.kernel_size,
                                                        self.stride, self.pad_mode, self.padding, self.dilation,
                                                        self.group, self.fake, self.freeze_bn, self.momentum)
        return s

    def construct(self, x):
        """construct."""
        out_conv = self.conv(x, self.weight)
        if self.has_bias:
            out_conv = self.bias_add(out_conv, self.bias)
        # BN fold1: compute batch statistics and updated running statistics from the first conv output
        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
                                                                               self.moving_mean,
                                                                               self.moving_variance,
                                                                               self.step)
        # fold the BN scale into the weight and, when self.fake is True, fake-quantize it
        weight = self.correct_mul(self.weight, self.gamma, running_std)
        if self.fake:
            weight = self.fake_quant_weight(weight)
        out = self.conv(x, weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        # BN fold2: apply the folded Batch Normalization bias correction to the conv output
        if self.is_gpu:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
                self.assignadd(self.step, self.one)
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
        else:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                self.assignadd(self.step, self.one)
            else:
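                # during inference the running statistics stand in for the batch statistics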
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
        return out
